{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "SocialAttentionDQN",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3"
  },
  "accelerator": "GPU",
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sepDWoBqdRMK",
    "colab_type": "text"
   },
   "source": [
    "# Training a DQN with social attention on `intersection-v0`\n",
    "\n",
    "## Import requirements"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Kx8X4s8krNWt",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# Environment\n",
    "!pip install highway-env\n",
    "import gym\n",
    "import highway_env\n",
    "\n",
    "# Agent\n",
    "!pip install git+https://github.com/eleurent/rl-agents#egg=rl-agents\n",
    "\n",
    "# Visualisation utils\n",
    "import sys\n",
    "%load_ext tensorboard\n",
    "!pip install tensorboardx gym pyvirtualdisplay\n",
    "!apt-get install -y xvfb python-opengl ffmpeg\n",
    "!git clone https://github.com/eleurent/highway-env.git\n",
    "sys.path.insert(0, '/content/highway-env/scripts/')\n",
    "from utils import show_videos"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vvOEW00pdHrG",
    "colab_type": "text"
   },
   "source": [
    "## Training\n",
    "\n",
    "Prepare environment, agent, and evaluation process.\n",
    "\n",
    "We use a policy architecture based on social attention, see [[Leurent and Mercat, 2019]](https://arxiv.org/abs/1911.12250).\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "QowKW3ix45ZW",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "from rl_agents.trainer.evaluation import Evaluation\n",
    "from rl_agents.agents.common.factory import load_agent, load_environment\n",
    "\n",
    "# Get the environment and agent configurations from the rl-agents repository\n",
    "!git clone https://github.com/eleurent/rl-agents.git\n",
    "%cd /content/rl-agents/scripts/\n",
    "env_config = 'configs/IntersectionEnv/env.json'\n",
    "agent_config = 'configs/IntersectionEnv/agents/DQNAgent/ego_attention_2h.json'\n",
    "\n",
    "env = load_environment(env_config)\n",
    "agent = load_agent(agent_config, env)\n",
    "evaluation = Evaluation(env, agent, num_episodes=3000, display_env=False)\n",
    "print(f\"Ready to train {agent} on {env}\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nqnGqW6jd1xN",
    "colab_type": "text"
   },
   "source": [
    "Run tensorboard locally to visualize training."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "q7QJY2wc4_1N",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "%tensorboard --logdir \"{evaluation.directory}\""
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BtK9dtfb0JMF",
    "colab_type": "text"
   },
   "source": [
    "Start training. This should take about an hour."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "sFVq1gFz42Eg",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "evaluation.train()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-lNvWg42RWiw",
    "colab_type": "text"
   },
   "source": [
    "Progress can be visualised in the tensorboard cell above, which should update every 30s (or manually). You may need to click the *Fit domain to data* buttons below each graph."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VKfvu5uhzCIU",
    "colab_type": "text"
   },
   "source": [
    "## Testing\n",
    "\n",
    "Run the learned policy for a few episodes."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "gY0rpVYUtRpN",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "env = load_environment(env_config)\n",
    "env.configure({\"offscreen_rendering\": True})\n",
    "agent = load_agent(agent_config, env)\n",
    "evaluation = Evaluation(env, agent, num_episodes=3, recover=True)\n",
    "evaluation.test()\n",
    "show_videos(evaluation.run_directory)"
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}